# ======================================================================
# Copyright CERFACS (February 2019)
# Contributor: Adrien Suau (adrien.suau@cerfacs.fr)
#
# This software is governed by the CeCILL-B license under French law and
# abiding  by the  rules of  distribution of free software. You can use,
# modify  and/or  redistribute  the  software  under  the  terms  of the
# CeCILL-B license as circulated by CEA, CNRS and INRIA at the following
# URL "http://www.cecill.info".
#
# As a counterpart to the access to  the source code and rights to copy,
# modify and  redistribute granted  by the  license, users  are provided
# only with a limited warranty and  the software's author, the holder of
# the economic rights,  and the  successive licensors  have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using, modifying and/or  developing or reproducing  the
# software by the user in light of its specific status of free software,
# that  may mean  that it  is complicated  to manipulate,  and that also
# therefore  means that  it is reserved for  developers and  experienced
# professionals having in-depth  computer knowledge. Users are therefore
# encouraged  to load and  test  the software's  suitability as  regards
# their  requirements  in  conditions  enabling  the  security  of their
# systems  and/or  data to be  ensured and,  more generally,  to use and
# operate it in the same conditions as regards security.
#
# The fact that you  are presently reading this  means that you have had
# knowledge of the CeCILL-B license and that you accept its terms.
# ======================================================================

r"""This module implement several methods related to the quantum wave equation solver.

The main function is :func:`~evolve_1D_dirichlet` but other function can be used to
have more control on the generated circuits, such as
:func:`~evolve_1D_dirichlet_no_repetition`.
"""
import logging
import typing as ty

import numpy
from qat.lang.AQASM.routines import QRoutine
from qat.lang.AQASM.misc import build_gate

from qaths.applications.wave_equation.time_adaptation import (
    adapt_dirichlet_evolution_time,
)
from qaths.applications.wave_equation.utils import (
    compute_qubit_number_from_considered_points_1d,
)
from qaths.simulation.base.integer_weighted import (
    simulate_signed_integer_weighted_hamiltonian,
)
from qaths.simulation.pf.repetition_computation import compute_r
from qaths.simulation.pf.trotter import simulate_using_trotter
from qaths.utils.math import normalise, next_power_of_2
from qaths.utils.simulation import simulate_routine
from qaths.applications.wave_equation.linking_sets import basic

logger = logging.getLogger("qaths.applications.wave_equation.evolve_1D_dirichlet")

try:
    # Only the import is guarded: any other ImportError raised while defining the
    # gate below should propagate instead of being reported as a missing oracle.
    from qat.linalg.oracles import StatePreparation
except ImportError:
    logger.warning(
        "StatePreparation oracle not found. You are probably using MyQLM "
        "that does not contain this feature. Disabling the "
        "initialise_1d_dirichlet_stationary function."
    )
else:

    @build_gate(
        "initialise_1d_dirichlet_stationary",
        [numpy.ndarray],
        arity=lambda arr: int(numpy.ceil(numpy.log2(2 * arr.size - 3))),
    )
    def initialise_1d_dirichlet_stationary(initial_phi_v: numpy.ndarray) -> QRoutine:
        """Prepare an initial quantum state for a stationary initial condition.

        The prepared state stores the interior values ``initial_phi_v[1:-1]`` in its
        first ``initial_phi_v.size - 2`` amplitudes. All the remaining amplitudes are
        zero, up to the next power of two. The whole vector is normalised before
        being handed to ``StatePreparation``.

        :param initial_phi_v: The initial state to give to the wave equation solver.
            The time derivative of the initial state is considered to be zero. Its
            first and last entries must be (close to) 0 because of the Dirichlet
            boundary conditions.
        :return: A quantum gate initialising the qubits to the desired quantum state.
        :raises AssertionError: if the input has a non-negligible imaginary part or
            does not satisfy the Dirichlet boundary conditions.
        """
        # qat seems to reshape the numpy array and to change its type to complex.
        # Flatten and realify it *before* indexing, so that the boundary checks
        # below always compare scalars (a 2-D array would make phi_v[0] a row).
        phi_v = initial_phi_v.reshape((-1,))
        assert (
            numpy.linalg.norm(numpy.imag(phi_v)) < 1e-10
        ), "The initial state should be real"
        phi_v = numpy.real(phi_v)

        assert numpy.isclose(
            phi_v[0], 0
        ), "Imposed Dirichlet boundary conditions: phi(0) = 0"
        assert numpy.isclose(
            phi_v[-1], 0
        ), "Imposed Dirichlet boundary conditions: phi(1) = 0"

        phi_v_size = phi_v.size - 2
        phi_e_size = phi_v_size + 1
        state_size = phi_v_size + phi_e_size  # equals 2 * phi_v.size - 3, see arity

        rout = QRoutine()
        qubits = rout.new_wires(int(numpy.ceil(numpy.log2(state_size))))

        # Interior values first, then zeros padding the vector to a power of two.
        statevector = normalise(
            numpy.hstack(
                (
                    phi_v[1:-1],
                    numpy.zeros(next_power_of_2(state_size) - phi_v_size),
                )
            )
        )
        rout.apply(StatePreparation(statevector), qubits)
        return rout


@build_gate(
    "evolve_1d_dirichlet_no_repetition_no_time_adjustment",
    [float, int, int],
    arity=lambda t, discr, trotter: compute_qubit_number_from_considered_points_1d(
        discr - 2
    ),
)
def _evolve_1d_dirichlet_no_repetition_no_time_adjustment(
    time: float, discretisation_points_number: int, trotter_order: int = 1
) -> QRoutine:
    """Build a single Trotter-Suzuki product simulating the wave-equation Hamiltonian.

    Neither repetitions nor evolution-time adaptation are handled here. ``time`` is
    forwarded unchanged to :func:`simulate_using_trotter`, which combines the two
    signed integer-weighted Hamiltonians named ``"oracle1"`` and ``"oracle2"``.
    Their concrete implementations are presumably resolved through the linking set
    used at simulation time.

    :param time: The time for which we want to solve the wave equation, forwarded
        as-is to the Trotter-Suzuki simulation.
    :param discretisation_points_number: The number of points in the 1D line
        discretisation (boundary points included).
    :param trotter_order: The order of the Trotter-Suzuki formula used. A higher
        order gives a more precise simulation at the cost of more gates.
    :return: a quantum circuit evolving the given quantum state following the
        wave equation for a time t = time / (discretisation_points_number - 1).
    """
    # The two extremities of the line are not encoded in the quantum state.
    qubit_count = compute_qubit_number_from_considered_points_1d(
        discretisation_points_number - 2
    )
    # Matrix entries have only 0 and 1 as weights (the sign lives on another qubit),
    # so a single qubit is enough to encode each weight.
    weight_qubits = 1

    def _hamiltonian_simulator(oracle_name: str):
        """Return a callable simulating Hamiltonian ``oracle_name`` for a time t."""
        return lambda t: simulate_signed_integer_weighted_hamiltonian(
            oracle_name, qubit_count, weight_qubits, t
        )

    trotter_gate = simulate_using_trotter(
        trotter_order,
        time,
        [_hamiltonian_simulator(name) for name in ("oracle1", "oracle2")],
    )
    routine = QRoutine()
    routine.apply(trotter_gate, routine.new_wires(trotter_gate.arity))
    return routine


@build_gate(
    "evolve_1d_dirichlet_no_repetition",
    [float, int, int],
    arity=lambda t, discr, trotter: compute_qubit_number_from_considered_points_1d(
        discr - 2
    ),
)
def evolve_1d_dirichlet_no_repetition(
    time: float, discretisation_points_number: int, trotter_order: int = 1
) -> QRoutine:
    """Construct a quantum gate solving the wave equation with one Trotter product.

    The wave-equation time is converted into a Hamiltonian simulation time with
    :func:`adapt_dirichlet_evolution_time`. That whole duration is then simulated
    with a single Trotter-Suzuki product, with no repetitions, so the precision
    is not controlled here.

    :param time: The time for which we want to solve the wave equation.
    :param discretisation_points_number: The number of points in the 1D line
        discretisation (boundary points included).
    :param trotter_order: The order of the Trotter-Suzuki formula used. A higher
        order gives a more precise simulation at the cost of more gates.
    :return: a quantum circuit evolving the given quantum state following the
        wave equation.
    """
    single_step = _evolve_1d_dirichlet_no_repetition_no_time_adjustment(
        adapt_dirichlet_evolution_time(time, discretisation_points_number),
        discretisation_points_number,
        trotter_order,
    )
    routine = QRoutine()
    wires = routine.new_wires(single_step.arity)
    routine.apply(single_step, wires)
    return routine


@build_gate(
    "evolve_1d_dirichlet_no_time_adjustment",
    [float, int, float, int],
    arity=lambda t, discr, eps, trotter: compute_qubit_number_from_considered_points_1d(
        discr - 2
    ),
)
def evolve_1d_dirichlet_no_time_adjustment(
    time: float,
    discretisation_points_number: int,
    epsilon: float,
    trotter_order: int = 1,
    compute_repetitions=compute_r,
) -> QRoutine:
    """Construct a quantum gate simulating the Hamiltonian for ``time``, with repetitions.

    Unlike :func:`evolve_1d_dirichlet`, no time adaptation is performed. ``time`` is
    used directly as the Hamiltonian simulation time, both to compute the number of
    repetitions ``r`` and to build the repeated Trotter-Suzuki step. That step
    simulates the Hamiltonian for ``time / r``.

    :param time: The Hamiltonian simulation time (not adapted).
    :param discretisation_points_number: The number of points in the 1D line
        discretisation (boundary points included).
    :param epsilon: Desired precision.
    :param trotter_order: The order of the Trotter-Suzuki formula used. A higher
        order gives a more precise simulation at the cost of more gates.
    :param compute_repetitions: A function taking the simulation time, the precision,
        the number of matrices in the sum-decomposition, the maximum spectral norm of
        the matrices in the sum-decomposition and the trotter-order used (in this order)
        and returning a number of repetitions r ensuring that the given precision is
        achieved.
    :return: a quantum circuit applying ``r`` times the same Trotter-Suzuki step.
    :raises ValueError: if ``compute_repetitions`` returns fewer than 1 repetition.
    """
    # No time adaptation in this variant: the given time already is the time for
    # which the Hamiltonian is simulated.
    simulation_time = time

    # We have 2 Hamiltonians in our decomposition.
    decomposition_size = 2
    # Each Hamiltonian is a unitary matrix of size m, completed by zeros to obtain
    # a size of 2**n. Because of this, each Hamiltonian has m non-zero eigenvalues
    # that are all of unit norm. The spectral norm of each Hamiltonian is then 1.
    max_spectral_norm = 1.0

    r = compute_repetitions(
        simulation_time, epsilon, decomposition_size, max_spectral_norm, trotter_order
    )
    # compute_repetitions is user-pluggable. r == 0 would divide by zero below, and
    # r < 0 would silently produce an empty circuit.
    if r < 1:
        raise ValueError(
            f"compute_repetitions returned {r} repetitions, at least 1 is required."
        )

    # Each of the r repetitions simulates the Hamiltonian for simulation_time / r.
    gate = _evolve_1d_dirichlet_no_repetition_no_time_adjustment(
        simulation_time / r, discretisation_points_number, trotter_order
    )
    routine = QRoutine()
    qubits = routine.new_wires(gate.arity)
    for _ in range(r):
        routine.apply(gate, qubits)
    return routine


@build_gate(
    "evolve_1d_dirichlet",
    [float, int, float, int],
    arity=lambda t, discr, eps, trotter: compute_qubit_number_from_considered_points_1d(
        discr - 2
    ),
)
def evolve_1d_dirichlet(
    time: float,
    discretisation_points_number: int,
    epsilon: float,
    trotter_order: int = 1,
    compute_repetitions=compute_r,
) -> QRoutine:
    """Construct a quantum gate solving the wave equation up to a given precision.

    The wave-equation time ``time`` is first converted into a Hamiltonian simulation
    time with :func:`adapt_dirichlet_evolution_time`. The number of repetitions ``r``
    is computed from that simulation time. The returned routine then applies ``r``
    times a Trotter-Suzuki step simulating the Hamiltonian for
    ``simulation_time / r``.

    :param time: The time for which we want to solve the wave equation.
    :param discretisation_points_number: The number of points in the 1D line
        discretisation (boundary points included).
    :param epsilon: Desired precision.
    :param trotter_order: The order of the Trotter-Suzuki formula used. A higher
        order gives a more precise simulation at the cost of more gates.
    :param compute_repetitions: A function taking the simulation time, the precision,
        the number of matrices in the sum-decomposition, the maximum spectral norm of
        the matrices in the sum-decomposition and the trotter-order used (in this order)
        and returning a number of repetitions r ensuring that the given precision is
        achieved.
    :return: a quantum circuit evolving the given quantum state following the
        wave equation.
    :raises ValueError: if ``compute_repetitions`` returns fewer than 1 repetition.
    """
    # The Hamiltonian simulation time must be known now, because it is the time
    # the number of repetitions r has to be computed for.
    simulation_time = adapt_dirichlet_evolution_time(time, discretisation_points_number)

    # We have 2 Hamiltonians in our decomposition.
    decomposition_size = 2
    # Each Hamiltonian is a unitary matrix of size m, completed by zeros to obtain
    # a size of 2**n. Because of this, each Hamiltonian has m non-zero eigenvalues
    # that are all of unit norm. The spectral norm of each Hamiltonian is then 1.
    max_spectral_norm = 1.0

    r = compute_repetitions(
        simulation_time, epsilon, decomposition_size, max_spectral_norm, trotter_order
    )
    # compute_repetitions is user-pluggable. r == 0 would divide by zero below, and
    # r < 0 would silently produce an empty circuit.
    if r < 1:
        raise ValueError(
            f"compute_repetitions returned {r} repetitions, at least 1 is required."
        )

    # Solving the wave equation for "time" means simulating the Hamiltonian for
    # "simulation_time". The time was already adapted above, so each repetition
    # uses the no-time-adjustment gate for "simulation_time / r".
    gate = _evolve_1d_dirichlet_no_repetition_no_time_adjustment(
        simulation_time / r, discretisation_points_number, trotter_order
    )
    routine = QRoutine()
    qubits = routine.new_wires(gate.arity)
    for _ in range(r):
        routine.apply(gate, qubits)
    return routine


def solve_1d_dirichlet_stationary(
    time: float,
    discretisation_points_number: int,
    epsilon: float,
    initial_phi_v: ty.Optional[numpy.ndarray] = None,
    trotter_order: int = 1,
    linking_set: ty.Optional[ty.List] = None,
    probability_threshold: float = 1e-8,
    imaginary_part_threshold: float = 1e-10,
) -> numpy.ndarray:
    """Solve the 1D wave equation with Dirichlet boundary conditions by simulation.

    A circuit is built from the optional state preparation of ``initial_phi_v``
    followed by :func:`evolve_1d_dirichlet`, and then simulated. The first
    ``discretisation_points_number - 2`` amplitudes of the resulting state vector
    are returned, surrounded by a zero at each end for the two boundary points.

    :param time: The time for which we want to solve the wave equation.
    :param discretisation_points_number: The number of points in the 1D line
        discretisation (boundary points included).
    :param epsilon: Desired precision of the simulation.
    :param initial_phi_v: The initial state to give to the wave equation solver.
        The time derivative of the initial state is considered to be zero. A value
        of None (default value) or a null vector (full of 0) means that the
        initialisation step should be skipped.
    :param trotter_order: The order of the Trotter-Suzuki formula used. A higher
        order gives a more precise simulation at the cost of more gates.
    :param linking_set: gate implementations to use as a list of AbstractGates with
        implementation. Defaults to the basic linking set in
        qaths.applications.wave_equation.linking_sets.basic
    :param probability_threshold: probabilities under this threshold will be ignored.
    :param imaginary_part_threshold: if the 2-norm of the imaginary part of the
        solution is not below this threshold, an AssertionError is raised.
    :return: the real part of the solution, including the two (zero) boundary
        points.
    :raises AssertionError: if the solution has a non-negligible imaginary part.
    :raises NotImplementedError: if a non-zero ``initial_phi_v`` is given but the
        StatePreparation oracle (and thus the state initialiser) is unavailable.
    """
    ignored_points = 2  # The 2 extremities are ignored
    considered_points = discretisation_points_number - ignored_points

    n = compute_qubit_number_from_considered_points_1d(considered_points)

    if linking_set is None:
        linking_set = basic.get_linking_set(n, discretisation_points_number)

    evolve_gate = evolve_1d_dirichlet(
        time, discretisation_points_number, epsilon, trotter_order
    )

    routine = QRoutine()
    input_state = routine.new_wires(n)
    if initial_phi_v is not None and not numpy.isclose(
        numpy.linalg.norm(initial_phi_v), 0
    ):
        # The initialiser is only defined when the StatePreparation oracle could be
        # imported. Report its absence explicitly instead of with a bare NameError.
        try:
            initialiser = initialise_1d_dirichlet_stationary
        except NameError as err:
            raise NotImplementedError(
                "A non-zero initial state requires the StatePreparation oracle, "
                "which is not available in this installation (MyQLM?)."
            ) from err
        routine.apply(initialiser(initial_phi_v), input_state)
    routine.apply(evolve_gate, input_state)

    statevector = simulate_routine(
        routine,
        probability_threshold=probability_threshold,
        link=linking_set,
        box_routines=True,
    )
    solution_without_boundaries = statevector[:considered_points]
    # Dirichlet boundary conditions: phi is zero at both extremities.
    solution = numpy.concatenate(
        (numpy.zeros(1), solution_without_boundaries, numpy.zeros(1))
    )
    # We expect a real solution.
    imaginary_norm = numpy.linalg.norm(numpy.imag(solution))
    assert imaginary_norm < imaginary_part_threshold, (
        f"The solution should be real (2-norm of the imaginary part is "
        f"{imaginary_norm} which is not lower than the "
        f"imaginary part threshold '{imaginary_part_threshold}')"
    )
    return numpy.real(solution)
